import sys
import inspect
from functools import wraps
import os.path as op
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as patches

def get_default_args(func):
    """
    Collect the default values declared in a callable's signature.
    Parameters:
        func: callable to inspect
    Return:
        dict mapping parameter name -> default value; parameters that
        declare no default are left out.
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        # `Parameter.empty` is the sentinel inspect uses for "no default".
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults

def check_outputs(**kwargs):
    """
    Decorator factory that skips a stage when all of its output files exist.
    Decorator options (keyword arguments):
        stage_name: label printed in the start message ("-" if omitted)
        override: when truthy, the stage is always executed
    The decorated function is called with keyword arguments only and must
    be given `out_files`, a dict whose values are the paths to check.
    A `do` keyword is added to the call: True when at least one of those
    paths does not exist, False otherwise. The stage is skipped (and None
    is returned) only when `do` is False and neither the decorator nor the
    call carries a truthy `override`.
    """
    def Inner(func):
        @wraps(func)
        def wrapper(**func_kwargs):
            print("Info: start stage <%s>, data route -> %s"
                  % (kwargs.get("stage_name", "-"),
                     func_kwargs.get("data_route", "")))
            assert "out_files" in func_kwargs
            targets = func_kwargs.get("out_files")
            # True as soon as one expected output is missing on disk.
            needs_run = any(not op.exists(path) for path in targets.values())
            func_kwargs["do"] = needs_run
            forced = (kwargs.get("override", False)
                      or func_kwargs.get("override", False))
            if needs_run or forced:
                # `do` is forwarded, so the stage can tell a forced re-run
                # (do=False) from a run caused by missing outputs (do=True).
                return func(**func_kwargs)
            print("Info: data ready, skipped.")
            return None
        return wrapper
    return Inner


def check_inputs(**kwargs):
    """
    Decorator factory to check whether all input files are ready.
    Decorator options (keyword arguments):
        inputs: iterable of keys that must be present in `input_files`
                (defaults to an empty list)
    The decorated function is called with keyword arguments only. Its
    declared defaults are merged with the call's keywords, and the merged
    set must contain `input_files`, a dict mapping key -> file path.
    Return (of the wrapped call):
        the function's own return value when every path exists, otherwise
        a list of error strings, one per missing file.
    Exceptions:
        AssertionError when `input_files` is absent.
        SystemExit (status 1) when a required key from `inputs` is missing.
    """
    def Inner(func):
        @wraps(func)
        def wrapper(**kwds):
            # Start from the declared defaults so the function is always
            # called with a full keyword set, then overlay the caller's values.
            func_kwargs = get_default_args(func)
            func_kwargs.update(kwds)
            assert ("input_files" in func_kwargs)
            input_files = func_kwargs["input_files"]

            for f in kwargs.get("inputs", []):
                if f not in input_files:
                    print("Error: need file <%s>."%(f))
                    # Non-zero status: a missing required input is a failure
                    # and must not be reported to the caller as success.
                    sys.exit(1)

            errors = [
                "Error: input file <%s> is null."%(k)
                for k, v in input_files.items()
                if not op.exists(v)
            ]
            if errors:
                return errors
            return func(**func_kwargs)
        return wrapper
    return Inner

def get_dims_from_coordstxt(coords_txt):
    """
    Get the latitude and longitude dimensions from coords.txt.
    The third line of the file must contain "File Dimension:" followed by
    two integers separated by "x" (latitude first, longitude second).
    Parameters:
        coords_txt: path of the txt file
    Return:
        lat_dim, lon_dim, geo_size
        geo_size is (lon_dim, lat_dim). A file with fewer than three lines
        yields 0, 0, ().
    """
    dims_row = 2
    lat_dim = lon_dim = 0
    geo_size = ()
    with open(coords_txt, 'r') as handle:
        for row_no, text in enumerate(handle):
            if row_no > dims_row:
                break
            if row_no < dims_row:
                continue
            assert "File Dimension:" in text
            # Everything after the last ":" is "<lat> x <lon>".
            _, _, dims_text = text.rpartition(":")
            lat_dim, lon_dim = list(map(int, dims_text.split("x")))
            geo_size = (lon_dim, lat_dim)
    return lat_dim, lon_dim, geo_size

def load_depth(depth_txt):
    """
    Parse depth.txt into coordinate and depth arrays.
    Columns 0 and 1 are read as the coordinates and column 2 as the depth,
    all as float32.
    Parameters:
        depth_txt: path for depth.txt
    Return:
        label_coords, label_depths
    """
    def read_columns(columns):
        # One pass over the file per column group.
        return np.loadtxt(depth_txt, dtype="float32", usecols=columns)

    return read_columns([0, 1]), read_columns([2])

def load_npz(npz_file):
    """
    Load npz data; entries are returned in the order listed by the archive.
    Parameters:
        npz_file: npz file saved by `savez_compressed`
    Return:
        tuple with one array per entry of the archive.
    """
    # The archive object keeps the underlying file open until it is closed,
    # so read every entry inside the `with` block and release the handle.
    with np.load(npz_file) as data:
        return tuple(data[entry_name] for entry_name in data.files)

def dump_npz(output_file, **kwargs):
    """
    Dump data to a compressed npz file to cache the result.
    Parameters:
        output_file: path for the file to dump
        kwargs: data entries with keys.
    Return:
        None
    Exceptions:
        SystemExit (status 1) when the dump fails; the failure message and
        its cause are printed first.
    """
    try:
        np.savez_compressed(output_file, **kwargs)
    except Exception as err:
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate untouched instead of being reported as a dump failure.
        print("Error: npz file generated failed.")
        print("Cause: %s" % (err,))
        # Non-zero status so callers and shells see the failure.
        sys.exit(1)

def dump_labels(output_file,
                img_data,
                label_val,
                masks,
                y_inds,
                x_inds):
    """
    Plot an image with the ground-truth patch positions highlighted and
    save it to a file. Patch positions are drawn as yellow points and
    masks as red rectangle outlines.
    Parameters:
        output_file: path to dump
        img_data: the band data of the hyperspectral image
        label_val: ndarray, label value (not used for drawing)
        y_inds, x_inds: ndarray, indexes for y & x direction.
        masks: list, [[x_tl, y_tl, x_br, y_br]]
    """
    fig = plt.figure("Labels")
    try:
        plt.imshow(img_data, cmap=plt.get_cmap("twilight_shifted"))
        plt.scatter(x_inds, y_inds, color='yellow', s=0.5, marker="D")
        axes = plt.gca()
        for x1, y1, x2, y2 in masks:
            # Rectangle takes the top-left corner plus width and height.
            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                     linewidth=1, edgecolor='r',
                                     facecolor='none')
            axes.add_patch(rect)
        plt.savefig(output_file)
    finally:
        # Always release the named figure: if it stayed open after an
        # error, the next call would reuse it and draw over stale content.
        plt.close(fig)

def clip_inds(y_inds, x_inds,
              dim_y, dim_x,
              depths, masks):

    """
    Clip indexes to the image bounds and drop the samples lying in a mask.
    Parameters:
        y_inds, x_inds: indexes for y & x axis; the first axis enumerates
                        the samples, any further axes are reduced with max
                        to get the position that is tested against masks
        dim_y, dim_x: scope for y & x axis, indexes are clipped to
                      [0, dim - 1]
        depths: labels, one per sample
        masks: list of masks for the filter, each
               (x_min, y_min, x_max, y_max) with inclusive bounds
    Return:
        y_inds, x_inds, depths: indexes and labels post filtering
    """
    y_inds = np.array(y_inds).clip(0, dim_y-1)
    x_inds = np.array(x_inds).clip(0, dim_x-1)
    depths = np.array(depths)
    if y_inds.shape[0] == 0:
        return y_inds, x_inds, depths

    # The per-sample test position does not depend on the mask, so compute
    # it once. Re-deriving it inside the mask loop failed with ValueError
    # as soon as an earlier mask had removed every sample (reshape(0, -1)).
    y_pos = y_inds.reshape(y_inds.shape[0], -1).max(axis=1)
    x_pos = x_inds.reshape(x_inds.shape[0], -1).max(axis=1)

    keep = np.ones(y_pos.shape[0], dtype=bool)
    for x_min, y_min, x_max, y_max in masks:
        # A sample survives a mask only when it lies strictly outside it.
        keep &= ((y_pos > y_max) | (y_pos < y_min) |
                 (x_pos > x_max) | (x_pos < x_min))

    inds = np.where(keep)
    return y_inds[inds], x_inds[inds], depths[inds]


